# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================


from npu_bridge.npu_init import *
import tensorflow as tf
from stnet import spatial_transformer_network as transformer
import numpy as np
import argparse
from tf_utils import weight_variable, bias_variable, dense_to_one_hot
import os
import time
from npu_bridge.estimator.npu.npu_loss_scale_optimizer import NPULossScaleOptimizer
from npu_bridge.estimator.npu.npu_loss_scale_manager import FixedLossScaleManager
from npu_bridge.estimator.npu.npu_loss_scale_manager import ExponentialUpdateLossScaleManager

# %% Load data

parser = argparse.ArgumentParser(description="train mnist",
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--data_path', type=str, default='s3://stnet/dataset',
                    help='path to the training data')
parser.add_argument('--train_epochs', type=int, default=500,
                    help='maximum number of epochs')
parser.add_argument('--output_path', type=str, default='s3://stnet/stnet-npu',
                    help='path where the model is saved')
# Note: argparse's type=bool treats any non-empty string (even "False") as
# True, so parse the flag explicitly.
parser.add_argument('--modelarts', type=lambda s: s.lower() in ('true', '1', 'yes'),
                    default=True,
                    help='whether the platform is ModelArts')
args, unknown = parser.parse_known_args()
train_file = os.path.join(args.data_path, 'mnist_sequence1_sample_5distortions5x5.npz')
ops_info_file = 'ops_info.json'
mnist_cluttered = np.load(train_file, allow_pickle=True)
e2e_start_time = time.time()
X_train = mnist_cluttered['X_train']
y_train = mnist_cluttered['y_train']
X_valid = mnist_cluttered['X_valid']
y_valid = mnist_cluttered['y_valid']
X_test = mnist_cluttered['X_test']
y_test = mnist_cluttered['y_test']

# %% Turn from dense to one-hot representation
Y_train = dense_to_one_hot(y_train, n_classes=10)
Y_valid = dense_to_one_hot(y_valid, n_classes=10)
Y_test = dense_to_one_hot(y_test, n_classes=10)

# %% Graph representation of our network

# %% Placeholders for 40x40 resolution
x = tf.placeholder(tf.float32, [None, 1600])
y = tf.placeholder(tf.float32, [None, 10])

# %% Since x is currently [batch, height*width], we need to reshape to a
# 4-D tensor to use it in a convolutional graph.  If one component of
# `shape` is the special value -1, the size of that dimension is
# computed so that the total size remains constant.  Since we haven't
# defined the batch dimension's shape yet, we use -1 to denote this
# dimension should not change size.
x_tensor = tf.reshape(x, [-1, 40, 40, 1])
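# For example, a batch of 256 flattened images with shape [256, 1600] becomes
# a [256, 40, 40, 1] tensor after this reshape.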

# %% We'll set up the two-layer localisation network to figure out the
# %% parameters for an affine transformation of the input
# %% Create variables for fully connected layer
W_fc_loc1 = weight_variable([1600, 20])
b_fc_loc1 = bias_variable([20])

W_fc_loc2 = weight_variable([20, 6])
# Use identity transformation as starting point
initial = np.array([[1., 0, 0], [0, 1., 0]], dtype='float32').flatten()
b_fc_loc2 = tf.Variable(initial_value=initial, name='b_fc_loc2')
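# The 6 values predicted by this layer form the 2x3 affine matrix
#   [[a, b, t_x],
#    [c, d, t_y]]
# so the identity initialisation above (a = d = 1, everything else 0) makes
# the transformer pass the input through unchanged at the start of training.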

# %% Define the two layer localisation network
h_fc_loc1 = tf.nn.tanh(tf.matmul(x, W_fc_loc1) + b_fc_loc1)
# %% We can add dropout for regularization, to reduce overfitting, like so:
keep_prob = tf.placeholder(tf.float32)
h_fc_loc1_drop = npu_ops.dropout(h_fc_loc1, keep_prob)
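# npu_ops.dropout (from npu_bridge) is used here as the Ascend-friendly
# replacement for tf.nn.dropout.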
# %% Second layer
h_fc_loc2 = tf.nn.tanh(tf.matmul(h_fc_loc1_drop, W_fc_loc2) + b_fc_loc2)

# %% We'll create a spatial transformer module to identify discriminative
# %% patches
out_size = (40, 40)
h_trans = transformer(x_tensor, h_fc_loc2, out_size)
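# As in Jaderberg et al.'s Spatial Transformer Networks, the transformer
# builds a sampling grid from the predicted affine parameters and bilinearly
# samples the 40x40 input at those grid points. Bilinear sampling is
# differentiable, so the localisation network above is trained end to end
# with the classifier.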

# %% We'll set up the first convolutional layer
# Weight matrix is [height x width x input_channels x output_channels]
filter_size = 3
n_filters_1 = 16
W_conv1 = weight_variable([filter_size, filter_size, 1, n_filters_1])

# %% Bias is [output_channels]
b_conv1 = bias_variable([n_filters_1])

# %% Now we can build a graph which does the first layer of convolution:
# we define our stride as batch x height x width x channels
# instead of pooling, we use strides of 2 and more layers
# with smaller filters.

h_conv1 = tf.nn.relu(
    tf.nn.conv2d(input=h_trans,
                 filter=W_conv1,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv1)
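# With stride 2 and 'SAME' padding each convolution halves the spatial
# resolution: 40x40 -> 20x20 here, and 20x20 -> 10x10 after the second layer
# below, which is where the 10 * 10 * n_filters_2 flatten size comes from.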

# %% And just like the first layer, add additional layers to create
# a deep net
n_filters_2 = 16
W_conv2 = weight_variable([filter_size, filter_size, n_filters_1, n_filters_2])
b_conv2 = bias_variable([n_filters_2])
h_conv2 = tf.nn.relu(
    tf.nn.conv2d(input=h_conv1,
                 filter=W_conv2,
                 strides=[1, 2, 2, 1],
                 padding='SAME') +
    b_conv2)

# %% We'll now reshape so we can connect to a fully-connected layer:
h_conv2_flat = tf.reshape(h_conv2, [-1, 10 * 10 * n_filters_2])

# %% Create a fully-connected layer:
n_fc = 1024
W_fc1 = weight_variable([10 * 10 * n_filters_2, n_fc])
b_fc1 = bias_variable([n_fc])
h_fc1 = tf.nn.relu(tf.matmul(h_conv2_flat, W_fc1) + b_fc1)

h_fc1_drop = npu_ops.dropout(h_fc1, keep_prob)

# %% And finally our softmax layer:
W_fc2 = weight_variable([n_fc, 10])
b_fc2 = bias_variable([10])
y_logits = tf.matmul(h_fc1_drop, W_fc2) + b_fc2

# %% Define loss/eval/training functions
cross_entropy = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_logits, labels=y))
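# softmax_cross_entropy_with_logits_v2 applies the softmax itself, which is
# why y_logits above is left unnormalised.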
opt = npu_tf_optimizer(tf.train.AdamOptimizer())

loss_scale_manager = ExponentialUpdateLossScaleManager(init_loss_scale=2 ** 32, incr_every_n_steps=1000,
                                                       decr_every_n_nan_or_inf=2, decr_ratio=0.5)
optimizer_temp = NPULossScaleOptimizer(opt, loss_scale_manager)
optimizer = optimizer_temp.minimize(cross_entropy)
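# Loss scaling protects small float16 gradients from underflowing under mixed
# precision: the loss is multiplied by the scale before backpropagation and
# the resulting gradients are divided by it again. With the manager configured
# above, the scale grows after every 1000 consecutive overflow-free steps and
# is halved (decr_ratio=0.5) once 2 steps produce NaN/Inf gradients.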
# Gradients of the loss w.r.t. the transformer bias (kept for inspection;
# not used in the training step below).
grads = opt.compute_gradients(cross_entropy, [b_fc_loc2])

# %% Monitor accuracy
correct_prediction = tf.equal(tf.argmax(y_logits, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, 'float'))

# %% We now create a new session to actually initialize the
# variables:
saver = tf.train.Saver()
config_proto = tf.ConfigProto()
custom_op = config_proto.graph_options.rewrite_options.custom_optimizers.add()
custom_op.name = "NpuOptimizer"
custom_op.parameter_map["mix_compile_mode"].b = True
custom_op.parameter_map["use_off_line"].b = True
custom_op.parameter_map["precision_mode"].s = tf.compat.as_bytes("allow_mix_precision")
custom_op.parameter_map["modify_mixlist"].s = tf.compat.as_bytes("ops_info.json")
config_proto.graph_options.rewrite_options.remapping = RewriterConfig.OFF
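# Semantics of the NPU options above, per the npu_bridge documentation:
# use_off_line=True runs the graph on the Ascend device instead of the host;
# mix_compile_mode=True lets unsupported operators fall back to the host;
# allow_mix_precision enables automatic float32/float16 mixing, with
# modify_mixlist pointing at a JSON file that adjusts the default
# mixed-precision op lists; remapping is switched off because that grappler
# pass can conflict with the NPU optimizer.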
# Optional dump/profiling switches (disabled by default):
# custom_op.parameter_map['enable_dump_debug'].b = True
# custom_op.parameter_map['dump_debug_mode'].s = tf.compat.as_bytes('all')
# custom_op.parameter_map['dump_path'].s = tf.compat.as_bytes(saveDir + '/')
# custom_op.parameter_map["profiling_mode"].b = True
# custom_op.parameter_map["profiling_options"].s = tf.compat.as_bytes(
#     '{"output":"/cache/profiling",'
#     '"training_trace":"on",'
#     '"task_trace":"on",'
#     '"fp_point":"random_normal/mul",'
#     '"bp_point":"gradients_1/add_1_grad/Sum_1"}')
config = npu_config_proto(config_proto=config_proto)
sess = tf.Session(config=config)
sess.run(tf.global_variables_initializer())

# %% We'll now train in minibatches and report loss and accuracy:
iter_per_epoch = 100
n_epochs = args.train_epochs
train_size = 10000

indices = np.linspace(0, train_size - 1, iter_per_epoch)
indices = indices.astype('int')
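# linspace gives iter_per_epoch evenly spaced boundary indices over the
# training set; consecutive pairs below yield iter_per_epoch - 1 minibatches
# of roughly 100 images each.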

for epoch_i in range(n_epochs):
    for iter_i in range(iter_per_epoch - 1):
        start_time = time.time()
        batch_xs = X_train[indices[iter_i]:indices[iter_i + 1]]
        batch_ys = Y_train[indices[iter_i]:indices[iter_i + 1]]

        # Evaluate the loss (dropout disabled) every 10 steps; the per-step
        # log below therefore shows the most recently evaluated loss.
        if iter_i % 10 == 0:
            loss = sess.run(cross_entropy,
                            feed_dict={
                                x: batch_xs,
                                y: batch_ys,
                                keep_prob: 1.0
                            })

        sess.run(optimizer, feed_dict={
            x: batch_xs, y: batch_ys, keep_prob: 0.8})
        cost_time = time.time() - start_time
        print("epoch : {}----step : {}----loss : {}----sec/step : {}".format(epoch_i, iter_i, loss, cost_time))
acc = sess.run(accuracy,
               feed_dict={
                   x: X_valid,
                   y: Y_valid,
                   keep_prob: 1.0
               })
print('accuracy : {}'.format(acc))
saver.save(sess, os.path.join(args.output_path, "model.ckpt"))
e2e_cost_time = time.time() - e2e_start_time
with open(os.path.join(args.output_path, "performance_precision.txt"), "w") as file_write:
    write_str = "Final Accuracy : " + str(round(acc, 4))
    print(write_str)
    file_write.write(write_str)
    file_write.write('\r\n')

    write_str = "Final Performance ms/step : " + str(round(cost_time * 1000, 4))
    print(write_str)
    file_write.write(write_str)
    file_write.write('\r\n')

    write_str = "Final Training Duration sec : " + str(round(e2e_cost_time, 4))
    print(write_str)
    file_write.write(write_str)
    file_write.write('\r\n')
end_time = time.time()
print('Training completed in seconds:', (end_time - e2e_start_time))
print('One epoch completed in seconds:', (end_time - e2e_start_time) / n_epochs)
print('****************************************************')
sess.close()